# coding=utf-8 # # This file is part of Hypothesis, which may be found at # https://github.com/HypothesisWorks/hypothesis-python # # Most of this work is copyright (C) 2013-2018 David R. MacIver # (david@drmaciver.com), but it contains contributions by others. See # CONTRIBUTING.rst for a full list of people who may hold copyright, and # consult the git log if you need to determine who owns an individual # contribution. # # This Source Code Form is subject to the terms of the Mozilla Public License, # v. 2.0. If a copy of the MPL was not distributed with this file, You can # obtain one at http://mozilla.org/MPL/2.0/. # # END HEADER from __future__ import division, print_function, absolute_import import heapq from enum import Enum from random import Random, getrandbits from weakref import WeakKeyDictionary from collections import defaultdict import attr from hypothesis import Phase, Verbosity, HealthCheck from hypothesis import settings as Settings from hypothesis.reporting import debug_report from hypothesis.internal.compat import Counter, ceil, hbytes, hrange, \ int_to_text, int_to_bytes, benchmark_time, int_from_bytes, \ to_bytes_sequence, unicode_safe_repr from hypothesis.utils.conventions import UniqueIdentifier from hypothesis.internal.healthcheck import fail_health_check from hypothesis.internal.conjecture.data import MAX_DEPTH, Status, \ StopTest, ConjectureData from hypothesis.internal.conjecture.minimizer import minimize, minimize_int # Tell pytest to omit the body of this module from tracebacks # http://doc.pytest.org/en/latest/example/simple.html#writing-well-integrated-assertion-helpers __tracebackhide__ = True HUNG_TEST_TIME_LIMIT = 5 * 60 @attr.s class HealthCheckState(object): valid_examples = attr.ib(default=0) invalid_examples = attr.ib(default=0) overrun_examples = attr.ib(default=0) draw_times = attr.ib(default=attr.Factory(list)) class ExitReason(Enum): max_examples = 0 max_iterations = 1 timeout = 2 max_shrinks = 3 finished = 4 flaky = 5 class RunIsComplete(Exception): pass class ConjectureRunner(object): def __init__( self, test_function, settings=None, random=None, database_key=None, ): self._test_function = test_function self.settings = settings or Settings() self.shrinks = 0 self.call_count = 0 self.event_call_counts = Counter() self.valid_examples = 0 self.start_time = benchmark_time() self.random = random or Random(getrandbits(128)) self.database_key = database_key self.status_runtimes = {} self.all_drawtimes = [] self.all_runtimes = [] self.events_to_strings = WeakKeyDictionary() self.target_selector = TargetSelector(self.random) # Tree nodes are stored in an array to prevent heavy nesting of data # structures. Branches are dicts mapping bytes to child nodes (which # will in general only be partially populated). Leaves are # ConjectureData objects that have been previously seen as the result # of following that path. self.tree = [{}] # A node is dead if there is nothing left to explore past that point. # Recursively, a node is dead if either it is a leaf or every byte # leads to a dead node when starting from here. self.dead = set() # We rewrite the byte stream at various points during parsing, to one # that will produce an equivalent result but is in some sense more # canonical. We keep track of these so that when walking the tree we # can identify nodes where the exact byte value doesn't matter and # treat all bytes there as equivalent. This significantly reduces the # size of the search space and removes a lot of redundant examples. # Maps tree indices where to the unique byte that is valid at that # point. Corresponds to data.write() calls. self.forced = {} # Maps tree indices to the maximum byte that is valid at that point. # Currently this is only used inside draw_bits, but it potentially # could get used elsewhere. self.capped = {} # Where a tree node consists of the beginning of a block we track the # size of said block. This allows us to tell when an example is too # short even if it goes off the unexplored region of the tree - if it # is at the beginning of a block of size 4 but only has 3 bytes left, # it's going to overrun the end of the buffer regardless of the # buffer contents. self.block_sizes = {} self.interesting_examples = {} self.covering_examples = {} self.shrunk_examples = set() self.tag_intern_table = {} self.health_check_state = None self.used_examples_from_database = False def __tree_is_exhausted(self): return 0 in self.dead def test_function(self, data): if benchmark_time() - self.start_time >= HUNG_TEST_TIME_LIMIT: fail_health_check(self.settings, ( 'Your test has been running for at least five minutes. This ' 'is probably not what you intended, so by default Hypothesis ' 'turns it into an error.' ), HealthCheck.hung_test) self.call_count += 1 try: self._test_function(data) data.freeze() except StopTest as e: if e.testcounter != data.testcounter: self.save_buffer(data.buffer) raise e except BaseException: self.save_buffer(data.buffer) raise finally: data.freeze() self.note_details(data) self.target_selector.add(data) self.debug_data(data) tags = frozenset(data.tags) data.tags = self.tag_intern_table.setdefault(tags, tags) if data.status == Status.VALID: self.valid_examples += 1 for t in data.tags: existing = self.covering_examples.get(t) if ( existing is None or sort_key(data.buffer) < sort_key(existing.buffer) ): self.covering_examples[t] = data if self.database is not None: self.database.save(self.covering_key, data.buffer) if existing is not None: self.database.delete( self.covering_key, existing.buffer) tree_node = self.tree[0] indices = [] node_index = 0 for i, b in enumerate(data.buffer): indices.append(node_index) if i in data.forced_indices: self.forced[node_index] = b try: self.capped[node_index] = data.capped_indices[i] except KeyError: pass try: node_index = tree_node[b] except KeyError: node_index = len(self.tree) self.tree.append({}) tree_node[b] = node_index tree_node = self.tree[node_index] if node_index in self.dead: break for u, v in data.blocks: # This can happen if we hit a dead node when walking the buffer. # In that case we alrady have this section of the tree mapped. if u >= len(indices): break self.block_sizes[indices[u]] = v - u self.dead.update(indices[self.cap:]) if data.status != Status.OVERRUN and node_index not in self.dead: self.dead.add(node_index) self.tree[node_index] = data for j in reversed(indices): if ( len(self.tree[j]) < self.capped.get(j, 255) + 1 and j not in self.forced ): break if set(self.tree[j].values()).issubset(self.dead): self.dead.add(j) else: break if data.status == Status.INTERESTING: key = data.interesting_origin changed = False try: existing = self.interesting_examples[key] except KeyError: changed = True else: if sort_key(data.buffer) < sort_key(existing.buffer): self.shrinks += 1 self.downgrade_buffer(existing.buffer) changed = True if changed: self.save_buffer(data.buffer) self.interesting_examples[key] = data self.shrunk_examples.discard(key) if self.shrinks >= self.settings.max_shrinks: self.exit_with(ExitReason.max_shrinks) if ( self.settings.timeout > 0 and benchmark_time() >= self.start_time + self.settings.timeout ): self.exit_with(ExitReason.timeout) if not self.interesting_examples: if self.valid_examples >= self.settings.max_examples: self.exit_with(ExitReason.max_examples) if self.call_count >= max( self.settings.max_iterations, self.settings.max_examples ): self.exit_with(ExitReason.max_iterations) if self.__tree_is_exhausted(): self.exit_with(ExitReason.finished) self.record_for_health_check(data) def generate_novel_prefix(self): prefix = bytearray() node = 0 while True: assert len(prefix) < self.cap assert node not in self.dead upper_bound = self.capped.get(node, 255) + 1 try: c = self.forced[node] prefix.append(c) node = self.tree[node][c] continue except KeyError: pass c = self.random.randrange(0, upper_bound) try: next_node = self.tree[node][c] if next_node in self.dead: choices = [ b for b in hrange(upper_bound) if self.tree[node].get(b) not in self.dead ] assert choices c = self.random.choice(choices) node = self.tree[node][c] else: node = next_node prefix.append(c) except KeyError: prefix.append(c) break assert node not in self.dead return hbytes(prefix) @property def cap(self): return self.settings.buffer_size // 2 def record_for_health_check(self, data): # Once we've actually found a bug, there's no point in trying to run # health checks - they'll just mask the actually important information. if data.status == Status.INTERESTING: self.health_check_state = None state = self.health_check_state if state is None: return state.draw_times.extend(data.draw_times) if data.status == Status.VALID: state.valid_examples += 1 elif data.status == Status.INVALID: state.invalid_examples += 1 else: assert data.status == Status.OVERRUN state.overrun_examples += 1 max_valid_draws = 10 max_invalid_draws = 50 max_overrun_draws = 20 assert state.valid_examples <= max_valid_draws if state.valid_examples == max_valid_draws: self.health_check_state = None return if state.overrun_examples == max_overrun_draws: fail_health_check(self.settings, ( 'Examples routinely exceeded the max allowable size. ' '(%d examples overran while generating %d valid ones)' '. Generating examples this large will usually lead to' ' bad results. You could try setting max_size parameters ' 'on your collections and turning ' 'max_leaves down on recursive() calls.') % ( state.overrun_examples, state.valid_examples ), HealthCheck.data_too_large) if state.invalid_examples == max_invalid_draws: fail_health_check(self.settings, ( 'It looks like your strategy is filtering out a lot ' 'of data. Health check found %d filtered examples but ' 'only %d good ones. This will make your tests much ' 'slower, and also will probably distort the data ' 'generation quite a lot. You should adapt your ' 'strategy to filter less. This can also be caused by ' 'a low max_leaves parameter in recursive() calls') % ( state.invalid_examples, state.valid_examples ), HealthCheck.filter_too_much) draw_time = sum(state.draw_times) if draw_time > 1.0: fail_health_check(self.settings, ( 'Data generation is extremely slow: Only produced ' '%d valid examples in %.2f seconds (%d invalid ones ' 'and %d exceeded maximum size). Try decreasing ' "size of the data you're generating (with e.g." 'max_size or max_leaves parameters).' ) % ( state.valid_examples, draw_time, state.invalid_examples, state.overrun_examples), HealthCheck.too_slow,) def save_buffer(self, buffer): if self.settings.database is not None: key = self.database_key if key is None: return self.settings.database.save(key, hbytes(buffer)) def downgrade_buffer(self, buffer): if self.settings.database is not None: self.settings.database.move( self.database_key, self.secondary_key, buffer) @property def secondary_key(self): return b'.'.join((self.database_key, b'secondary')) @property def covering_key(self): return b'.'.join((self.database_key, b'coverage')) def note_details(self, data): runtime = max(data.finish_time - data.start_time, 0.0) self.all_runtimes.append(runtime) self.all_drawtimes.extend(data.draw_times) self.status_runtimes.setdefault(data.status, []).append(runtime) for event in set(map(self.event_to_string, data.events)): self.event_call_counts[event] += 1 def debug(self, message): with self.settings: debug_report(message) def debug_data(self, data): if self.settings.verbosity < Verbosity.debug: return buffer_parts = [u"["] for i, (u, v) in enumerate(data.blocks): if i > 0: buffer_parts.append(u" || ") buffer_parts.append( u', '.join(int_to_text(int(i)) for i in data.buffer[u:v])) buffer_parts.append(u']') status = unicode_safe_repr(data.status) if data.status == Status.INTERESTING: status = u'%s (%s)' % ( status, unicode_safe_repr(data.interesting_origin,)) self.debug(u'%d bytes %s -> %s, %s' % ( data.index, u''.join(buffer_parts), status, data.output, )) def run(self): with self.settings: try: self._run() except RunIsComplete: pass for v in self.interesting_examples.values(): self.debug_data(v) self.debug( u'Run complete after %d examples (%d valid) and %d shrinks' % (self.call_count, self.valid_examples, self.shrinks)) def _new_mutator(self): target_data = [None] def draw_new(data, n): return uniform(self.random, n) def draw_existing(data, n): return target_data[0].buffer[data.index:data.index + n] def draw_smaller(data, n): existing = target_data[0].buffer[data.index:data.index + n] r = uniform(self.random, n) if r <= existing: return r return _draw_predecessor(self.random, existing) def draw_larger(data, n): existing = target_data[0].buffer[data.index:data.index + n] r = uniform(self.random, n) if r >= existing: return r return _draw_successor(self.random, existing) def reuse_existing(data, n): choices = data.block_starts.get(n, []) if choices: i = self.random.choice(choices) return hbytes(data.buffer[i:i + n]) else: result = uniform(self.random, n) assert isinstance(result, hbytes) return result def flip_bit(data, n): buf = bytearray( target_data[0].buffer[data.index:data.index + n]) i = self.random.randint(0, n - 1) k = self.random.randint(0, 7) buf[i] ^= (1 << k) return hbytes(buf) def draw_zero(data, n): return hbytes(b'\0' * n) def draw_max(data, n): return hbytes([255]) * n def draw_constant(data, n): return hbytes([self.random.randint(0, 255)]) * n def redraw_last(data, n): u = target_data[0].blocks[-1][0] if data.index + n <= u: return target_data[0].buffer[data.index:data.index + n] else: return uniform(self.random, n) options = [ draw_new, redraw_last, redraw_last, reuse_existing, reuse_existing, draw_existing, draw_smaller, draw_larger, flip_bit, draw_zero, draw_max, draw_zero, draw_max, draw_constant, ] bits = [ self.random.choice(options) for _ in hrange(3) ] prefix = [None] def mutate_from(origin): target_data[0] = origin prefix[0] = self.generate_novel_prefix() return draw_mutated def draw_mutated(data, n): if data.index + n > len(target_data[0].buffer): result = uniform(self.random, n) else: result = self.random.choice(bits)(data, n) p = prefix[0] if data.index < len(p): start = p[data.index:data.index + n] result = start + result[len(start):] return self.__zero_bound(data, result) return mutate_from def __rewrite(self, data, result): return self.__zero_bound(data, result) def __zero_bound(self, data, result): """This tries to get the size of the generated data under control by replacing the result with zero if we are too deep or have already generated too much data. This causes us to enter "shrinking mode" there and thus reduce the size of the generated data. """ initial = len(result) if data.depth * 2 >= MAX_DEPTH or data.index >= self.cap: data.forced_indices.update( hrange(data.index, data.index + initial)) data.hit_zero_bound = True result = hbytes(initial) elif data.index + initial >= self.cap: data.hit_zero_bound = True n = self.cap - data.index data.forced_indices.update( hrange(self.cap, data.index + initial)) result = result[:n] + hbytes(initial - n) assert len(result) == initial return result @property def database(self): if self.database_key is None: return None return self.settings.database def has_existing_examples(self): return ( self.database is not None and Phase.reuse in self.settings.phases ) def reuse_existing_examples(self): """If appropriate (we have a database and have been told to use it), try to reload existing examples from the database. If there are a lot we don't try all of them. We always try the smallest example in the database (which is guaranteed to be the last failure) and the largest (which is usually the seed example which the last failure came from but we don't enforce that). We then take a random sampling of the remainder and try those. Any examples that are no longer interesting are cleared out. """ if self.has_existing_examples(): self.debug('Reusing examples from database') # We have to do some careful juggling here. We have two database # corpora: The primary and secondary. The primary corpus is a # small set of minimized examples each of which has at one point # demonstrated a distinct bug. We want to retry all of these. # We also have a secondary corpus of examples that have at some # point demonstrated interestingness (currently only ones that # were previously non-minimal examples of a bug, but this will # likely expand in future). These are a good source of potentially # interesting examples, but there are a lot of them, so we down # sample the secondary corpus to a more manageable size. corpus = sorted( self.settings.database.fetch(self.database_key), key=sort_key ) desired_size = max(2, ceil(0.1 * self.settings.max_examples)) for extra_key in [self.secondary_key, self.covering_key]: if len(corpus) < desired_size: extra_corpus = list( self.settings.database.fetch(extra_key), ) shortfall = desired_size - len(corpus) if len(extra_corpus) <= shortfall: extra = extra_corpus else: extra = self.random.sample(extra_corpus, shortfall) extra.sort(key=sort_key) corpus.extend(extra) self.used_examples_from_database = len(corpus) > 0 for existing in corpus: last_data = ConjectureData.for_buffer(existing) try: self.test_function(last_data) finally: if last_data.status != Status.INTERESTING: self.settings.database.delete( self.database_key, existing) self.settings.database.delete( self.secondary_key, existing) def exit_with(self, reason): self.exit_reason = reason raise RunIsComplete() def generate_new_examples(self): if Phase.generate not in self.settings.phases: return zero_data = self.cached_test_function( hbytes(self.settings.buffer_size)) if zero_data.status == Status.OVERRUN or ( zero_data.status == Status.VALID and len(zero_data.buffer) * 2 > self.settings.buffer_size ): fail_health_check( self.settings, 'The smallest natural example for your test is extremely ' 'large. This makes it difficult for Hypothesis to generate ' 'good examples, especially when trying to reduce failing ones ' 'at the end. Consider reducing the size of your data if it is ' 'of a fixed size. You could also fix this by improving how ' 'your data shrinks (see https://hypothesis.readthedocs.io/en/' 'latest/data.html#shrinking for details), or by introducing ' 'default values inside your strategy. e.g. could you replace ' 'some arguments with their defaults by using ' 'one_of(none(), some_complex_strategy)?', HealthCheck.large_base_example ) if self.settings.perform_health_check: self.health_check_state = HealthCheckState() count = 0 while not self.interesting_examples and ( count < 10 or self.health_check_state is not None ): prefix = self.generate_novel_prefix() def draw_bytes(data, n): if data.index < len(prefix): result = prefix[data.index:data.index + n] if len(result) < n: result += uniform(self.random, n - len(result)) else: result = uniform(self.random, n) return self.__zero_bound(data, result) targets_found = len(self.covering_examples) last_data = ConjectureData( max_length=self.settings.buffer_size, draw_bytes=draw_bytes ) self.test_function(last_data) last_data.freeze() if len(self.covering_examples) > targets_found: count = 0 else: count += 1 mutations = 0 mutator = self._new_mutator() zero_bound_queue = [] while not self.interesting_examples: if zero_bound_queue: # Whenever we generated an example and it hits a bound # which forces zero blocks into it, this creates a weird # distortion effect by making certain parts of the data # stream (especially ones to the right) much more likely # to be zero. We fix this by redistributing the generated # data by shuffling it randomly. This results in the # zero data being spread evenly throughout the buffer. # Hopefully the shrinking this causes will cause us to # naturally fail to hit the bound. # If it doesn't then we will queue the new version up again # (now with more zeros) and try again. overdrawn = zero_bound_queue.pop() buffer = bytearray(overdrawn.buffer) # These will have values written to them that are different # from what's in them anyway, so the value there doesn't # really "count" for distributional purposes, and if we # leave them in then they can cause the fraction of non # zero bytes to increase on redraw instead of decrease. for i in overdrawn.forced_indices: buffer[i] = 0 self.random.shuffle(buffer) buffer = hbytes(buffer) def draw_bytes(data, n): result = buffer[data.index:data.index + n] if len(result) < n: result += hbytes(n - len(result)) return self.__rewrite(data, result) data = ConjectureData( draw_bytes=draw_bytes, max_length=self.settings.buffer_size, ) self.test_function(data) data.freeze() else: target, origin = self.target_selector.select() mutations += 1 targets_found = len(self.covering_examples) data = ConjectureData( draw_bytes=mutator(origin), max_length=self.settings.buffer_size ) self.test_function(data) data.freeze() if ( data.status > origin.status or len(self.covering_examples) > targets_found ): mutations = 0 elif ( data.status < origin.status or not self.target_selector.has_tag(target, data) or mutations >= 10 ): # Cap the variations of a single example and move on to # an entirely fresh start. Ten is an entirely arbitrary # constant, but it's been working well for years. mutations = 0 mutator = self._new_mutator() if getattr(data, 'hit_zero_bound', False): zero_bound_queue.append(data) mutations += 1 def _run(self): self.start_time = benchmark_time() self.reuse_existing_examples() self.generate_new_examples() self.shrink_interesting_examples() self.exit_with(ExitReason.finished) def shrink_interesting_examples(self): """If we've found interesting examples, try to replace each of them with a minimal interesting example with the same interesting_origin. We may find one or more examples with a new interesting_origin during the shrink process. If so we shrink these too. """ if ( Phase.shrink not in self.settings.phases or not self.interesting_examples ): return for prev_data in sorted( self.interesting_examples.values(), key=lambda d: sort_key(d.buffer) ): assert prev_data.status == Status.INTERESTING data = ConjectureData.for_buffer(prev_data.buffer) self.test_function(data) if data.status != Status.INTERESTING: self.exit_with(ExitReason.flaky) self.clear_secondary_key() while len(self.shrunk_examples) < len(self.interesting_examples): target, example = min([ (k, v) for k, v in self.interesting_examples.items() if k not in self.shrunk_examples], key=lambda kv: (sort_key(kv[1].buffer), sort_key(repr(kv[0]))), ) self.debug('Shrinking %r' % (target,)) def predicate(d): if d.status < Status.INTERESTING: return False return d.interesting_origin == target self.shrink(example, predicate) self.shrunk_examples.add(target) def clear_secondary_key(self): if self.has_existing_examples(): # If we have any smaller examples in the secondary corpus, now is # a good time to try them to see if they work as shrinks. They # probably won't, but it's worth a shot and gives us a good # opportunity to clear out the database. # It's not worth trying the primary corpus because we already # tried all of those in the initial phase. corpus = sorted( self.settings.database.fetch(self.secondary_key), key=sort_key ) cap = max( sort_key(v.buffer) for v in self.interesting_examples.values() ) for c in corpus: if sort_key(c) >= cap: break else: data = self.cached_test_function(c) if ( data.status != Status.INTERESTING or self.interesting_examples[data.interesting_origin] is not data ): self.settings.database.delete( self.secondary_key, c) def shrink(self, example, predicate): s = self.new_shrinker(example, predicate) s.shrink() return s.shrink_target def new_shrinker(self, example, predicate): return Shrinker(self, example, predicate) def prescreen_buffer(self, buffer): """Attempt to rule out buffer as a possible interesting candidate. Returns False if we know for sure that running this buffer will not produce an interesting result. Returns True if it might (because it explores territory we have not previously tried). This is purely an optimisation to try to reduce the number of tests we run. "return True" would be a valid but inefficient implementation. """ node_index = 0 n = len(buffer) for k, b in enumerate(buffer): if node_index in self.dead: return False try: # The block size at that point provides a lower bound on how # many more bytes are required. If the buffer does not have # enough bytes to fulfill that block size then we can rule out # this buffer. if k + self.block_sizes[node_index] > n: return False except KeyError: pass try: b = self.forced[node_index] except KeyError: pass try: b = min(b, self.capped[node_index]) except KeyError: pass try: node_index = self.tree[node_index][b] except KeyError: return True else: return False def cached_test_function(self, buffer): node_index = 0 for i in hrange(self.settings.buffer_size): try: c = self.forced[node_index] except KeyError: if i < len(buffer): c = buffer[i] else: c = 0 try: node_index = self.tree[node_index][c] except KeyError: break node = self.tree[node_index] if isinstance(node, ConjectureData): return node result = ConjectureData.for_buffer(buffer) self.test_function(result) return result def event_to_string(self, event): if isinstance(event, str): return event try: return self.events_to_strings[event] except KeyError: pass result = str(event) self.events_to_strings[event] = result return result def _draw_predecessor(rnd, xs): r = bytearray() any_strict = False for x in to_bytes_sequence(xs): if not any_strict: c = rnd.randint(0, x) if c < x: any_strict = True else: c = rnd.randint(0, 255) r.append(c) return hbytes(r) def _draw_successor(rnd, xs): r = bytearray() any_strict = False for x in to_bytes_sequence(xs): if not any_strict: c = rnd.randint(x, 255) if c > x: any_strict = True else: c = rnd.randint(0, 255) r.append(c) return hbytes(r) def sort_key(buffer): return (len(buffer), buffer) def uniform(random, n): return int_to_bytes(random.getrandbits(n * 8), n) class SampleSet(object): """Set data type with the ability to sample uniformly at random from it. The mechanism is that we store the set in two parts: A mapping of values to their index in an array. Sampling uniformly at random then becomes simply a matter of sampling from the array, but we can use the index for efficient lookup to add and remove values. """ __slots__ = ('__values', '__index') def __init__(self): self.__values = [] self.__index = {} def __len__(self): return len(self.__values) def __repr__(self): return 'SampleSet(%r)' % (self.__values,) def add(self, value): assert value not in self.__index # Adding simply consists of adding the value to the end of the array # and updating the index. self.__index[value] = len(self.__values) self.__values.append(value) def remove(self, value): # To remove a value we first remove it from the index. But this leaves # us with the value still in the array, so we have to fix that. We # can't simply remove the value from the array, as that would a) Be an # O(n) operation and b) Leave the index completely wrong for every # value after that index. # So what we do is we take the last element of the array and place it # in the position of the value we just deleted (if the value was not # already the last element of the array. If it was then we don't have # to do anything extra). This reorders the array, but that's OK because # we don't care about its order, we just need to sample from it. i = self.__index.pop(value) last = self.__values.pop() if i < len(self.__values): self.__values[i] = last self.__index[last] = i def choice(self, random): return random.choice(self.__values) class Negated(object): __slots__ = ('tag',) def __init__(self, tag): self.tag = tag NEGATED_CACHE = {} def negated(tag): try: return NEGATED_CACHE[tag] except KeyError: result = Negated(tag) NEGATED_CACHE[tag] = result return result universal = UniqueIdentifier('universal') class TargetSelector(object): """Data structure for selecting targets to use for mutation. The goal is to do a good job of exploiting novelty in examples without getting too obsessed with any particular novel factor. Roughly speaking what we want to do is give each distinct coverage target equal amounts of time. However some coverage targets may be harder to fuzz than others, or may only appear in a very small minority of examples, so we don't want to let those dominate the testing. Targets are selected according to the following rules: 1. We ideally want valid examples as our starting point. We ignore interesting examples entirely, and other than that we restrict ourselves to the best example status we've seen so far. If we've only seen OVERRUN examples we use those. If we've seen INVALID but not VALID examples we use those. Otherwise we use VALID examples. 2. Among the examples we've seen with the right status, when asked to select a target, we select a coverage target and return that along with an example exhibiting that target uniformly at random. Coverage target selection proceeds as follows: 1. Whenever we return an example from select, we update the usage count of each of its tags. 2. Whenever we see an example, we add it to the list of examples for all of its tags. 3. When selecting a tag, we select one with a minimal usage count. Among those of minimal usage count we select one with the fewest examples. Among those, we select one uniformly at random. This has the following desirable properties: 1. When two coverage targets are intrinsically linked (e.g. when you have multiple lines in a conditional so that either all or none of them will be covered in a conditional) they are naturally deduplicated. 2. Popular coverage targets will largely be ignored for considering what test to run - if every example exhibits a coverage target, picking an example because of that target is rather pointless. 3. When we discover new coverage targets we immediately exploit them until we get to the point where we've spent about as much time on them as the existing targets. 4. Among the interesting deduplicated coverage targets we essentially round-robin between them, but with a more consistent distribution than uniformly at random, which is important particularly for short runs. """ def __init__(self, random): self.random = random self.best_status = Status.OVERRUN self.reset() def reset(self): self.examples_by_tags = defaultdict(list) self.tag_usage_counts = Counter() self.tags_by_score = defaultdict(SampleSet) self.scores_by_tag = {} self.scores = [] self.mutation_counts = 0 self.example_counts = 0 self.non_universal_tags = set() self.universal_tags = None def add(self, data): if data.status == Status.INTERESTING: return if data.status < self.best_status: return if data.status > self.best_status: self.best_status = data.status self.reset() if self.universal_tags is None: self.universal_tags = set(data.tags) else: not_actually_universal = self.universal_tags - data.tags for t in not_actually_universal: self.universal_tags.remove(t) self.non_universal_tags.add(t) self.examples_by_tags[t] = list( self.examples_by_tags[universal] ) new_tags = data.tags - self.non_universal_tags for t in new_tags: self.non_universal_tags.add(t) self.examples_by_tags[negated(t)] = list( self.examples_by_tags[universal] ) self.example_counts += 1 for t in self.tags_for(data): self.examples_by_tags[t].append(data) self.rescore(t) def has_tag(self, tag, data): if tag is universal: return True if isinstance(tag, Negated): return tag.tag not in data.tags return tag in data.tags def tags_for(self, data): yield universal for t in data.tags: yield t for t in self.non_universal_tags: if t not in data.tags: yield negated(t) def rescore(self, tag): new_score = ( self.tag_usage_counts[tag], len(self.examples_by_tags[tag])) try: old_score = self.scores_by_tag[tag] except KeyError: pass else: self.tags_by_score[old_score].remove(tag) self.scores_by_tag[tag] = new_score sample = self.tags_by_score[new_score] if len(sample) == 0: heapq.heappush(self.scores, new_score) sample.add(tag) def select_tag(self): while True: peek = self.scores[0] sample = self.tags_by_score[peek] if len(sample) == 0: heapq.heappop(self.scores) else: return sample.choice(self.random) def select_example_for_tag(self, t): return self.random.choice(self.examples_by_tags[t]) def select(self): t = self.select_tag() self.mutation_counts += 1 result = self.select_example_for_tag(t) assert self.has_tag(t, result) for s in self.tags_for(result): self.tag_usage_counts[s] += 1 self.rescore(s) return t, result class Shrinker(object): """A shrinker is a child object of a ConjectureRunner which is designed to manage the associated state of a particular shrink problem. Currently the only shrink problem we care about is "interesting and with a particular interesting_origin", but this is abstracted into a general purpose predicate for more flexibility later - e.g. we are likely to want to shrink with respect to a particular coverage target later. Data with a status < VALID may be assumed not to satisfy the predicate. The expected usage pattern is that this is only ever called from within the engine. """ def __init__(self, engine, initial, predicate): """Create a shrinker for a particular engine, with a given starting point and predicate. When shrink() is called it will attempt to find an example for which predicate is True and which is strictly smaller than initial. Note that initial is a ConjectureData object, and predicate takes ConjectureData objects. """ self.__engine = engine self.__predicate = predicate self.__discarding_failed = False self.__shrinking_prefixes = set() # We keep track of the current best example on the shrink_target # attribute. self.shrink_target = None self.update_shrink_target(initial) def incorporate_new_buffer(self, buffer): buffer = hbytes(buffer[:self.shrink_target.index]) assert sort_key(buffer) < sort_key(self.shrink_target.buffer) if self.shrink_target.buffer.startswith(buffer): return False if not self.__engine.prescreen_buffer(buffer): return False assert sort_key(buffer) <= sort_key(self.shrink_target.buffer) data = ConjectureData.for_buffer(buffer) self.__engine.test_function(data) return self.incorporate_test_data(data) def incorporate_test_data(self, data): if ( self.__predicate(data) and sort_key(data.buffer) < sort_key(self.shrink_target.buffer) ): self.update_shrink_target(data) self.__shrinking_block_cache = {} if data.has_discards and not self.__discarding_failed: self.remove_discarded() return True return False def cached_test_function(self, buffer): result = self.__engine.cached_test_function(buffer) self.incorporate_test_data(result) return result def debug(self, msg): self.__engine.debug(msg) def shrink(self): """Run the full set of shrinks and update shrink_target. This method is "mostly idempotent" - calling it twice is unlikely to have any effect, though it has a non-zero probability of doing so. """ # We assume that if an all-zero block of bytes is an interesting # example then we're not going to do better than that. # This might not technically be true: e.g. for integers() | booleans() # the simplest example is actually [1, 0]. Missing this case is fairly # harmless and this allows us to make various simplifying assumptions # about the structure of the data (principally that we're never # operating on a block of all zero bytes so can use non-zeroness as a # signpost of complexity). if ( not any(self.shrink_target.buffer) or self.incorporate_new_buffer(hbytes(len(self.shrink_target.buffer))) ): return self.greedy_shrink() self.escape_local_minimum() def greedy_shrink(self): """Run a full set of greedy shrinks (that is, ones that will only ever move to a better target) and update shrink_target appropriately. This method iterates to a fixed point and so is idempontent - calling it twice will have exactly the same effect as calling it once. """ run_expensive_shrinks = False prev = None while prev is not self.shrink_target: prev = self.shrink_target # We reset our tracking of what changed at the beginning of the # loop so that we don't get distracted by things that change once # and then are stable thereafter. self.clear_change_tracking() self.remove_discarded() self.adaptive_example_deletion() self.zero_draws() self.minimize_duplicated_blocks() self.minimize_individual_blocks() self.reorder_blocks() self.lower_dependent_block_pairs() self.lower_common_block_offset() # Passes after this point are expensive: Prior to here they should # all involve no more than about n log(n) shrinks, but after here # they may be quadratic or worse. Running all of the passes until # they make no changes is important for correctness, but nothing # says we have to run all of them on each run! So if the fast # passes still seem to be making useful changes, we restart the # loop here and give them another go. # To avoid the case where the expensive shrinks unlock a trivial # change in one of the previous passes causing this to become much # more expensive by doubling the number of times we have to run # them to get to run the expensive passes again, we make this # decision "sticky" - once it's been useful to run the expensive # changes at least once, we always run them. if prev is self.shrink_target: run_expensive_shrinks = True if not run_expensive_shrinks: continue self.shrink_offset_pairs() self.interval_deletion_with_block_lowering() self.pass_to_interval() self.reorder_bytes() @property def blocks(self): return self.shrink_target.blocks @property def intervals(self): if self.__intervals is None: target = self.shrink_target intervals = set(target.blocks) intervals.add((0, target.index)) intervals.update( (ex.start, ex.end) for ex in target.examples if ex.start < ex.end ) intervals_by_level = {} for ex in target.examples: if ex.start < ex.end: intervals_by_level.setdefault(ex.depth, []).append(ex) for l in intervals_by_level.values(): for e1, e2 in zip(l, l[1:]): if ( not (e1.discarded or e2.discarded) and e1.end == e2.start ): intervals.add((e1.start, e2.end)) for i in hrange(len(target.blocks) - 1): intervals.add((target.blocks[i][0], target.blocks[i + 1][1])) # Intervals are sorted as longest first, then by interval start. self.__intervals = tuple(sorted( set(intervals), key=lambda se: (se[0] - se[1], se[0]) )) return self.__intervals def zero_draws(self): """Attempt to replace each draw call with its minimal possible value. This is intended as a fast-track to minimize whole sub-examples that don't matter as rapidly as possible. For example, suppose we had something like the following: ls = data.draw(st.lists(st.lists(st.integers()))) assert len(ls) >= 10 Then each of the elements of ls need to be minimized, and we can do that by deleting individual values from them, but we'd much rather do it fast rather than slow - deleting elements one at a time takes sum(map(len, ls)) shrinks, and ideally we'd do this in len(ls) shrinks as we try to replace each element with []. This pass does that by identifying the size of the "natural smallest" element here. It first tries replacing an entire interval with zero. This will sometimes work (e.g. when the interval is a block), but often what will happen is that there will be leftover zeros that spill over into the next example and ruin things - e.g. here if ls[0] is non-empty and we replace it with all zero, some of the extra zeros will be interpreted as terminating ls and will shrink it down to a one element list, causing the test to pass. So what we do instead is that once we've evaluated that shrink, we use the size of the intervals there to find other possible sizes that we could try replacing the interval with. In this case we'd spot that there is a one-byte interval starting at right place for ls[i] and try to replace it with that. This will successfully replace ls[i] with [] as desired. """ i = 0 while i < len(self.shrink_target.examples): ex = self.shrink_target.examples[i] buf = self.shrink_target.buffer if any(buf[ex.start:ex.end]): prefix = buf[:ex.start] suffix = buf[ex.end:] attempt = self.cached_test_function( prefix + hbytes(ex.length) + suffix ) if attempt.status == Status.VALID: replacement = attempt.examples[i] assert replacement.start == ex.start if replacement.length < ex.length: self.incorporate_new_buffer( prefix + hbytes(replacement.length) + suffix ) i += 1 def pass_to_interval(self): """Attempt to replace each interval with a subinterval. This is designed to deal with strategies that call themselves recursively. For example, suppose we had: binary_tree = st.deferred( lambda: st.one_of( st.integers(), st.tuples(binary_tree, binary_tree))) This pass guarantees that we can replace any binary tree with one of its subtrees - each of those will create an interval that the parent could validly be replaced with, and this pass will try doing that. This is pretty expensive - it takes O(len(intervals)^2) - so we run it late in the process when we've got the number of intervals as far down as possible. """ i = 0 while i < len(self.shrink_target.examples): ex = self.shrink_target.examples[i] changed = False for j in hrange(i + 1, len(self.shrink_target.examples)): child = self.shrink_target.examples[j] if child.start >= ex.end: break if child.length < ex.length: buf = self.shrink_target.buffer if self.incorporate_new_buffer( buf[:ex.start] + buf[child.start:child.end] + buf[ex.end:] ): changed = True break if not changed: i += 1 def is_shrinking_block(self, i): """Checks whether block i has been previously marked as a shrinking block. If the shrink target has changed since i was last checked, will attempt to calculate if an equivalent block in a previous shrink target was marked as shrinking. """ if not self.__shrinking_prefixes: return False try: return self.__shrinking_block_cache[i] except KeyError: pass t = self.shrink_target return self.__shrinking_block_cache.setdefault( i, t.buffer[:t.blocks[i][0]] in self.__shrinking_prefixes ) def lower_common_block_offset(self): """Sometimes we find ourselves in a situation where changes to one part of the byte stream unlock changes to other parts. Sometimes this is good, but sometimes this can cause us to exhibit exponential slow downs! e.g. suppose we had the following: m = draw(integers(min_value=0)) n = draw(integers(min_value=0)) assert abs(m - n) > 1 If this fails then we'll end up with a loop where on each iteration we reduce each of m and n by 2 - m can't go lower because of n, then n can't go lower because of m. This will take us O(m) iterations to complete, which is exponential in the data size, as we gradually zig zag our way towards zero. This can only happen if we're failing to reduce the size of the byte stream: The number of iterations that reduce the length of the byte stream is bounded by that length. So what we do is this: We keep track of which blocks are changing, and then if there's some non-zero common offset to them we try and minimize them all at once by lowering that offset. This may not work, and it definitely won't get us out of all possible exponential slow downs (an example of where it doesn't is where the shape of the blocks changes as a result of this bouncing behaviour), but it fails fast when it doesn't work and gets us out of a really nastily slow case when it does. """ if len(self.__changed_blocks) <= 1: return self.debug('Removing common block offset') current = self.shrink_target blocked = [current.buffer[u:v] for u, v in current.blocks] changed = sorted(self.__changed_blocks) ints = [int_from_bytes(blocked[i]) for i in changed] offset = min(ints) if offset == 0: return for i in hrange(len(ints)): ints[i] -= offset def reoffset(o): new_blocks = list(blocked) for i, v in zip(changed, ints): new_blocks[i] = int_to_bytes(v + o, len(blocked[i])) return self.incorporate_new_buffer(hbytes().join(new_blocks)) minimize_int(offset, reoffset) def shrink_offset_pairs(self): """Lower any two blocks offset from each other the same ammount. Before this shrink pass, two blocks explicitly offset from each other would not get minimized properly: >>> b = st.integers(0, 255) >>> find(st.tuples(b, b), lambda x: x[0] == x[1] + 1) (149,148) This expensive (O(n^2)) pass goes through every pair of non-zero blocks in the current shrink target and sees if the shrink target can be improved by applying an offset to both of them. """ self.debug('Shrinking offset pairs.') current = [self.shrink_target.buffer[u:v] for u, v in self.blocks] def int_from_block(i): return int_from_bytes(current[i]) def block_len(i): u, v = self.blocks[i] return v - u # Try reoffseting every pair def reoffset_pair(pair, o): n = len(self.blocks) # Number of blocks may have changed, need to validate valid_pair = [p for p in pair if p < n and int_from_block(p) > 0] if len(valid_pair) < 2: return m = min([int_from_block(p) for p in valid_pair]) new_blocks = [self.shrink_target.buffer[u:v] for u, v in self.blocks] for i in valid_pair: new_blocks[i] = int_to_bytes( int_from_block(i) + o - m, block_len(i)) buffer = hbytes().join(new_blocks) return self.incorporate_new_buffer(buffer) i = 0 while i < len(self.blocks): if not self.is_shrinking_block(i) and int_from_block(i) > 0: j = i + 1 while j < len(self.shrink_target.blocks): block_val = int_from_block(j) i_block_val = int_from_block(i) if not self.is_shrinking_block(j) \ and block_val > 0 and i_block_val > 0: offset = min(int_from_block(i), int_from_block(j)) # Save current before shrinking current = [self.shrink_target.buffer[u:v] for u, v in self.blocks] minimize_int( offset, lambda o: reoffset_pair((i, j), o)) j += 1 i += 1 def mark_shrinking(self, blocks): """Mark each of these blocks as a shrinking block: That is, lowering its value lexicographically may cause less data to be drawn after.""" t = self.shrink_target for i in blocks: if self.__shrinking_block_cache.get(i) is True: continue self.__shrinking_block_cache[i] = True prefix = t.buffer[:t.blocks[i][0]] self.__shrinking_prefixes.add(prefix) def clear_change_tracking(self): self.__changed_blocks.clear() def update_shrink_target(self, new_target): assert new_target.frozen if self.shrink_target is not None: if new_target.blocks != self.shrink_target.blocks: self.__changed_blocks = set() else: current = self.shrink_target.buffer new = new_target.buffer for i, (u, v) in enumerate(self.shrink_target.blocks): if ( i not in self.__changed_blocks and current[u:v] != new[u:v] ): self.__changed_blocks.add(i) else: self.__changed_blocks = set() self.shrink_target = new_target self.__shrinking_block_cache = {} self.__intervals = None def escape_local_minimum(self): """Attempt to restart the shrink process from a larger initial value in a way that allows us to escape a local minimum that the main greedy shrink process will get stuck in. The idea is that when we've completed the shrink process, we try starting it again from something reasonably near to the shrunk example that is likely to exhibit the same behaviour. We search for an example that is selected randomly among ones that are "structurally similar" to the original. If we don't find one we bail out fairly quickly as this will usually not work. If we do, we restart the shrink process from there. If this results in us finding a better final example, we do this again until it stops working. This is especially useful for things where the tendency to move complexity to the right works against us - often a generic instance of the problem is easy to shrink, but trying to reduce the size of a minimized example further is hard. For example suppose we had something like: x = data.draw(lists(integers())) y = data.draw(lists(integers(), min_size=len(x), max_size=len(x))) assert not (any(x) and any(y)) Then this could shrink to something like [0, 1], [0, 1]. Attempting to shrink this further by deleting an element of x would result in losing the last element of y, and the test would start passing. But if we were to replace this with [a, b], [c, d] with c != 0 then deleting a or b would work. """ count = 0 while count < 10: count += 1 self.debug('Retrying from random restart') attempt_buf = bytearray(self.shrink_target.buffer) # We use the shrinking information to identify the # structural locations in the byte stream - if lowering # the block would result in changing the size of the # example, changing it here is too likely to break whatever # it was caused the behaviour we're trying to shrink. # Everything non-structural, we redraw uniformly at random. for i, (u, v) in enumerate(self.blocks): if not self.is_shrinking_block(i): attempt_buf[u:v] = uniform(self.__engine.random, v - u) attempt = self.cached_test_function(attempt_buf) if self.__predicate(attempt): prev = self.shrink_target self.update_shrink_target(attempt) self.__shrinking_block_cache = {} self.greedy_shrink() if ( sort_key(self.shrink_target.buffer) < sort_key(prev.buffer) ): # We have successfully shrunk the example past where # we started from. Now we begin the whole processs # again from the new, smaller, example. count = 0 else: self.update_shrink_target(prev) self.__shrinking_block_cache = {} def try_shrinking_blocks(self, blocks, b): """Attempts to replace each block in the blocks list with b. Returns True if it succeeded (which may include some additional modifications to shrink_target). May call mark_shrinking with b if this causes a reduction in size. In current usage it is expected that each of the blocks currently have the same value, although this is not essential. Note that b must be < the block at min(blocks) or this is not a valid shrink. This method will attempt to do some small amount of work to delete data that occurs after the end of the blocks. This is useful for cases where there is some size dependency on the value of a block. The amount of work done here is relatively small - most such dependencies will be handled by the interval_deletion_with_block_lowering pass - but will be effective when there is a large amount of redundant data after the block to be lowered. """ initial_attempt = bytearray(self.shrink_target.buffer) for i in blocks: if i >= len(self.blocks): break u, v = self.blocks[i] n = min(v - u, len(b)) initial_attempt[v - n:v] = b[-n:] initial_data = self.cached_test_function(initial_attempt) if initial_data.status == Status.INTERESTING: return initial_data is self.shrink_target # If this produced something completely invalid we ditch it # here rather than trying to persevere. if initial_data.status < Status.VALID: return False if len(initial_data.buffer) < v: return False lost_data = len(self.shrink_target.buffer) - len(initial_data.buffer) # If this did not in fact cause the data size to shrink we # bail here because it's not worth trying to delete stuff from # the remainder. if lost_data <= 0: return False self.mark_shrinking(blocks) try_with_deleted = bytearray(initial_attempt) del try_with_deleted[v:v + lost_data] if self.incorporate_new_buffer(try_with_deleted): return True return False def remove_discarded(self): """Try removing all bytes marked as discarded. This pass is primarily to deal with data that has been ignored while doing rejection sampling - e.g. as a result of an integer range, or a filtered strategy. Such data will also be handled by the adaptive_example_deletion pass, but that pass is necessarily more conservative and will try deleting each interval individually. The common case is that all data drawn and rejected can just be thrown away immediately in one block, so this pass will be much faster than trying each one individually when it works. """ if not self.shrink_target.has_discards: return discarded = [] for ex in self.shrink_target.examples: if ex.discarded and ( not discarded or ex.start >= discarded[-1][-1] ): discarded.append((ex.start, ex.end)) attempt = bytearray(self.shrink_target.buffer) for u, v in reversed(discarded): del attempt[u:v] # We track whether discarding works because as long as it does we will # always want to run it whenever the option is available - whenever a # shrink ends up introducing new discarded data we can attempt to # delete it immediately. However if some discarded data looks essential # in some way then that would be wasteful, so we turn off the automatic # discarding if this ever fails. When this next runs explicitly, it # will reset the flag if the status changes. self.__discarding_failed = not self.incorporate_new_buffer(attempt) def adaptive_example_deletion(self): """Attempt to delete every draw call, plus some short sequences of draw calls. The only things this guarantees to attempt to delete are every draw call and every draw call plus its immediate successor (the first non-empty draw call that starts strictly after it). However if this seems to be working pretty well it will do its best to exploit that and adapt to the fact there's currently a lot that it can delete. This is the main point at which we try to lower the size of the data. e.g. if we have two successive draw calls, this will attempt to delete the first and replace it with the second. The fact that this will also try deleting the successor call is important. For example, if we have something like: while many.more(data): data.draw(stuff) This pass will attempt to delete adjacent pairs of calls to shorten the loop. """ self.debug('greedy interval deletes') i = 0 while i < len(self.shrink_target.examples): if self.shrink_target.examples[i].length == 0: i += 1 continue # Note: We do want this fixed rather than changing during this # iteration of the loop. target = self.shrink_target def try_delete_range(k): """Can we delete k non-trivial non-overlapping examples starting from i?""" stack = [] j = i while k > 0 and j < len(target.examples): ex = target.examples[j] if ex.length > 0 and ( not stack or stack[-1][1] <= ex.start ): stack.append((ex.start, ex.end)) k -= 1 j += 1 assert stack attempt = bytearray(target.buffer) for u, v in reversed(stack): del attempt[u:v] attempt = hbytes(attempt) if sort_key(attempt) >= sort_key(self.shrink_target.buffer): return False return self.incorporate_new_buffer(attempt) # This is an adaptive pass loosely modelled after timsort. If # little or nothing is deletable here then we don't try any more # deletions than the naive greedy algorithm would, but if it looks # like we have an opportunity to delete a lot then we try to do so. # What we're trying to do is to find a large k such that we can # delete k but not k + 1 draws starting from this point, and we # want to do that in O(log(k)) rather than O(k) test executions. # We try a quite careful sequence of small shrinks here before we # move on to anything big. This is because if we try to be # aggressive too early on we'll tend to find that we lose out when # the example is "nearly minimal". if try_delete_range(2): if try_delete_range(3) and try_delete_range(4): # At this point it looks like we've got a pretty good # opportunity for a long run here. We do an exponential # probe upwards to try and find some k where we can't # delete many intervals. We do this rather than choosing # that upper bound to immediately be large because we # don't really expect k to be huge. If it turns out that # it is, the subsequent example is going to be so tiny that # it doesn't really matter if we waste a bit of extra time # here. hi = 5 while try_delete_range(hi): assert hi <= len(target.examples) hi *= 2 # We now know that we can delete the first lo intervals but # not the first hi. We preserve that property while doing # a binary search to find the point at which we stop being # able to delete intervals. lo = 4 while lo + 1 < hi: mid = (lo + hi) // 2 if try_delete_range(mid): lo = mid else: hi = mid else: try_delete_range(1) # We unconditionally bump i because we have always tried deleting # one more example than we succeeded at deleting, so we expect the # next example to be undeletable. i += 1 def minimize_duplicated_blocks(self): """Find blocks that have been duplicated in multiple places and attempt to minimize all of the duplicates simultaneously. This lets us handle cases where two values can't be shrunk independently of each other but can easily be shrunk together. For example if we had something like: ls = data.draw(lists(integers())) y = data.draw(integers()) assert y not in ls Suppose we drew y = 3 and after shrinking we have ls = [3]. If we were to replace both 3s with 0, this would be a valid shrink, but if we were to replace either 3 with 0 on its own the test would start passing. It is also useful for when that duplication is accidental and the value of the blocks doesn't matter very much because it allows us to replace more values at once. """ self.debug('Simultaneous shrinking of duplicated blocks') def canon(b): i = 0 while i < len(b) and b[i] == 0: i += 1 return b[i:] counts = Counter( canon(self.shrink_target.buffer[u:v]) for u, v in self.blocks ) counts.pop(hbytes(), None) blocks = [buffer for buffer, count in counts.items() if count > 1] blocks.sort(reverse=True) blocks.sort(key=lambda b: counts[b] * len(b), reverse=True) for block in blocks: targets = [ i for i, (u, v) in enumerate(self.blocks) if canon(self.shrink_target.buffer[u:v]) == block ] # This can happen if some blocks have been lost in the previous # shrinking. if len(targets) <= 1: continue minimize( block, lambda b: self.try_shrinking_blocks(targets, b), random=self.__engine.random, full=False ) def minimize_individual_blocks(self): """Attempt to minimize each block in sequence. This is the pass that ensures that e.g. each integer we draw is a minimum value. So it's the part that guarantees that if we e.g. do x = data.draw(integers()) assert x < 10 then in our shrunk example, x = 10 rather than say 97. """ self.debug('Shrinking of individual blocks') i = 0 while i < len(self.blocks): u, v = self.blocks[i] minimize( self.shrink_target.buffer[u:v], lambda b: self.try_shrinking_blocks((i,), b), random=self.__engine.random, full=False, ) i += 1 def reorder_blocks(self): """Attempt to reorder blocks of the same size so that lexically larger values go later. This is mostly useful for canonicalization of examples. e.g. if we have x = data.draw(st.integers()) y = data.draw(st.integers()) assert x == y Then by minimizing x and y individually this could give us either x=0, y=1 or x=1, y=0. According to our sorting order, the former is a better example, but if in our initial draw y was zero then we will not get it. When this pass runs it will swap the values of x and y if that occurs. As well as canonicalization, this can also unblock other things. For example suppose we have n = data.draw(st.integers(0, 10)) ls = data.draw(st.lists(st.integers(), min_size=n, max_size=n)) assert len([x for x in ls if x != 0]) <= 1 We could end up with something like [1, 0, 0, 1] if we started from the wrong place. This pass would reorder this to [0, 0, 1, 1]. Shrinking n can then try to delete the lost bytes (see try_shrinking_blocks for how this works), taking us immediately to [1, 1]. This is a less important role for this pass, but still significant. """ self.debug('Reordering blocks') block_lengths = sorted(self.shrink_target.block_starts, reverse=True) for n in block_lengths: i = 1 while i < len(self.shrink_target.block_starts.get(n, ())): j = i while j > 0: buf = self.shrink_target.buffer blocks = self.shrink_target.block_starts[n] a_start = blocks[j - 1] b_start = blocks[j] a = buf[a_start:a_start + n] b = buf[b_start:b_start + n] if a <= b: break swapped = ( buf[:a_start] + b + buf[a_start + n:b_start] + a + buf[b_start + n:]) assert len(swapped) == len(buf) assert swapped < buf if self.incorporate_new_buffer(swapped): j -= 1 else: break i += 1 def interval_deletion_with_block_lowering(self): """This pass tries to delete each interval while replacing a block that precedes that interval with its immediate two lexicographical predecessors. We only do this for blocks that are marked as shrinking - that is, when we tried lowering them it resulted in a smaller example. This makes it important that this runs after minimize_individual_blocks (which populates those blocks). The reason for this pass is that it guarantees that we can delete elements of ls in the following scenario: n = data.draw(st.integers(0, 10)) ls = data.draw(st.lists(st.integers(), min_size=n, max_size=n)) Replacing the block for n with its predecessor replaces n with n - 1, and deleting a draw call in ls means that we draw exactly the desired n - 1 elements for this list. We actually also try replacing n with n - 2, as we will have intervals for adjacent pairs of draws and that ensures that those will find the right block lowering in this case too. This is necessarily a somewhat expensive pass - worst case scenario it tries len(blocks) * len(intervals) = O(len(buffer)^2 log(len(buffer))) shrinks, so it's important that it runs late in the process when the example size is small and most of the blocks that can be zeroed have been. """ self.debug('Lowering blocks while deleting intervals') i = 0 while i < len(self.intervals): u, v = self.intervals[i] changed = False # This loop never exits normally because the r >= u branch will # always trigger once we find a block inside the interval, hence # the pragma. for j, (r, s) in enumerate( # pragma: no branch self.blocks ): if r >= u: break if not self.is_shrinking_block(j): continue b = self.shrink_target.buffer[r:s] if any(b): n = int_from_bytes(b) for m in hrange(max(n - 2, 0), n): c = int_to_bytes(m, len(b)) attempt = bytearray(self.shrink_target.buffer) attempt[r:s] = c del attempt[u:v] if self.incorporate_new_buffer(attempt): changed = True break if changed: break if not changed: i += 1 def lower_dependent_block_pairs(self): """This is a fairly specific shrink pass that is mostly specialised for our integers strategy, though is probably useful in other places. It looks for adjacent pairs of blocks where lowering the value of the first changes the size of the second. Where this happens, we may lose interestingness because this takes the prefix rather than the suffix of the next block, so if lowering the block produces an uninteresting value and this change happens, we try replacing the second block with its suffix and shrink again. For example suppose we had: m = data.draw_bits(8) n = data.draw_bits(m) And initially we draw m = 9, n = 1. This gives us the bytes [9, 0, 1]. If we lower 9 to 8 then we now read [9, 0], because the block size of the n has changed. This pass allows us to also try [9, 1], which corresponds to m=8, n=1 as desired. This should *mostly* be handled by the minimize_individual_blocks pass, but that won't always work because its length heuristic can be wrong if the changes to the next block have knock on size changes, while this one triggers more reliably. """ self.debug('Lowering adjacent pairs of dependent blocks') i = 0 while i + 1 < len(self.blocks): u, v = self.blocks[i] i += 1 b = int_from_bytes(self.shrink_target.buffer[u:v]) if b > 0: attempt = bytearray(self.shrink_target.buffer) attempt[u:v] = int_to_bytes(b - 1, v - u) attempt = hbytes(attempt) shrunk = self.cached_test_function(attempt) if ( shrunk is not self.shrink_target and i < len(shrunk.blocks) and shrunk.blocks[i][1] < self.blocks[i][1] ): _, r = self.blocks[i] k = shrunk.blocks[i][1] - shrunk.blocks[i][0] buf = attempt[:v] + self.shrink_target.buffer[r - k:] self.incorporate_new_buffer(buf) def reorder_bytes(self): """This is a hyper-specific and moderately expensive shrink pass. It is designed to do similar things to reorder_blocks, but it works in cases where reorder_blocks may fail. The idea is that we expect to have a *lot* of single byte blocks, and they have very different meanings and interpretations. This means that the reasonably cheap approach of doing what is basically insertion sort on these blocks is unlikely to work. So instead we try to identify the subset of the single-byte blocks that we can freely move around and more aggressively put those into a sorted order. This is useful because e.g. we draw integers as single bytes, and if we don't have a pass like that then we're unable to shrink from [10, 0] to [0, 10]. In the event that we fail to do much sorting this is O(number of out of order pairs), which is O(n^2) in the worst case. In order to offset we try to do as much efficient sorting as possible to reduce the number of out of order pairs before we get to that stage. """ free_bytes = [] for i, (u, v) in enumerate(self.blocks): if ( v == u + 1 and u not in self.shrink_target.forced_indices ): free_bytes.append(u) if not free_bytes: return original = self.shrink_target def attempt(new_ordering): assert len(new_ordering) == len(free_bytes) assert len(self.shrink_target.buffer) == len(original.buffer) attempt = bytearray(self.shrink_target.buffer) for i, b in zip(free_bytes, new_ordering): attempt[i] = b return self.incorporate_new_buffer(attempt) ordering = [self.shrink_target.buffer[i] for i in free_bytes] if ordering == sorted(ordering): return if attempt(sorted(ordering)): return True n = len(ordering) # We now try to sort the "high bytes". The idea here is that high bytes # are more likely to be "payload" in some sense: Their value matters # mostly in relation to the other values. Additionally they are likely # to be moved around more in the reordering, so if we can get them # sorted up front we will save a lot of time later. # In order to do this we use binary search to find a value v such that # we can sort all values >= v. We do this in at most 8 steps (usually # less). # Invariant: We can sort the set of bytes which are >= hi, we can't # sort the set of bytes that are >= lo. # But see comment below about how these invariants may occasionally be # violated. lo = min(ordering) hi = max(ordering) while lo + 1 < hi: mid = (lo + hi) // 2 excessive = [i for i in hrange(n) if ordering[i] >= mid] trial = list(ordering) for i, b in zip(excessive, sorted(ordering[i] for i in excessive)): trial[i] = b if trial == ordering or attempt(trial): if ( len(self.shrink_target.buffer) != len(original.buffer) ): return # Technically this could result in us violating our invariants # if the bytes change too much. However if that happens the # loop is still useful so we carry on as if it didn't. ordering = [ self.shrink_target.buffer[i] for i in free_bytes] hi = mid else: lo = mid i = 1 while i < n: for k in hrange(i - 1, -1, -1): if ordering[k] <= ordering[i]: continue swapped = list(ordering) swapped[k], swapped[i] = swapped[i], swapped[k] if attempt(swapped): i = k if ( len(self.shrink_target.buffer) != len(original.buffer) ): return ordering = [ self.shrink_target.buffer[i] for i in free_bytes] break else: i += 1
Put your mouse over an expression. If we've saved any values for it, they'll go here.